{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import datasets\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the Baichuan-13B checkpoint; replace the placeholder before running.\n",
    "baichuan_13B_path = 'YOUR PATH HERE'\n",
    "# trust_remote_code=True allows transformers to run the modeling code shipped with the checkpoint.\n",
    "baichuan_13B = AutoModelForCausalLM.from_pretrained(baichuan_13B_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer loaded from the same checkpoint path as the model.\n",
    "tokenizer = AutoTokenizer.from_pretrained(baichuan_13B_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters of the layer-merging search loop further down.\n",
    "# Amount by which the base layer index is lowered after an accepted merge.\n",
    "INTERVAL = 1\n",
    "# Layers combined per merge attempt, counting the base layer (the loop passes MERGE_LAYERS-1 following layers).\n",
    "MERGE_LAYERS = 3\n",
    "# Reference index for the first attempt: the search starts at HIGHEST_LAY - MERGE_LAYERS.\n",
    "HIGHEST_LAY = 39\n",
    "# The search stops once the base layer index drops below this value.\n",
    "LOWEST_LAY = 0\n",
    "# A merge is kept only if the mean cosine similarity to the original model exceeds this value.\n",
    "THRESHOLD = 0.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "def merge_layers_return_model(model, merge_base_lay, merge_layer_num):\n",
    "    merge_layer_num = min(merge_layer_num, len(model.model.layers) - merge_base_lay - 1)\n",
    "    \n",
    "    model_copy = deepcopy(model)\n",
    "    for diff_lay in range(merge_base_lay+1, merge_base_lay+1+merge_layer_num):      \n",
    "        # gate_proj\n",
    "        model_copy.model.layers[merge_base_lay].mlp.gate_proj.weight.data.add_(\n",
    "            model.model.layers[diff_lay].mlp.gate_proj.weight.data - model_copy.model.layers[merge_base_lay].mlp.gate_proj.weight.data\n",
    "        )\n",
    "        # down_proj\n",
    "        model_copy.model.layers[merge_base_lay].mlp.down_proj.weight.data.add_(\n",
    "            model.model.layers[diff_lay].mlp.down_proj.weight.data - model_copy.model.layers[merge_base_lay].mlp.down_proj.weight.data\n",
    "        )\n",
    "        # up_proj\n",
    "        model_copy.model.layers[merge_base_lay].mlp.up_proj.weight.data.add_(\n",
    "            model.model.layers[diff_lay].mlp.up_proj.weight.data - model_copy.model.layers[merge_base_lay].mlp.up_proj.weight.data\n",
    "        )\n",
    "        # W_pack\n",
    "        model_copy.model.layers[merge_base_lay].self_attn.W_pack.weight.data.add_(\n",
    "            model.model.layers[diff_lay].self_attn.W_pack.weight.data - model_copy.model.layers[merge_base_lay].self_attn.W_pack.weight.data\n",
    "        )\n",
    "        # o_proj\n",
    "        model_copy.model.layers[merge_base_lay].self_attn.o_proj.weight.data.add_(\n",
    "            model.model.layers[diff_lay].self_attn.o_proj.weight.data - model_copy.model.layers[merge_base_lay].self_attn.o_proj.weight.data\n",
    "        )\n",
    "    for diff_lay in range(merge_base_lay+merge_layer_num, merge_base_lay, -1):\n",
    "        del(model_copy.model.layers[diff_lay])\n",
    "    return model_copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "# Working copy that the merge loop compresses; baichuan_13B itself is kept as the similarity reference.\n",
    "baichuan_13B_copy_to_compress = copy.deepcopy(baichuan_13B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def cal_last_hidden_sim(model1, model2, tokenizer, sents):\n",
    "    sim_ls = []\n",
    "    for s in sents:\n",
    "        encoded_inputs = tokenizer(s, return_tensors='pt')\n",
    "        with torch.no_grad():\n",
    "            outputs1 = model1(**encoded_inputs, output_hidden_states=True)\n",
    "        hidden_states1 = outputs1.hidden_states[-1] # (1, seq_len, hidden)\n",
    "        with torch.no_grad():\n",
    "            outputs2 = model2(**encoded_inputs, output_hidden_states=True)\n",
    "        hidden_states2 = outputs2.hidden_states[-1] # (1, seq_len, hidden)\n",
    "        sim_ls.append(torch.cosine_similarity(hidden_states1.squeeze(0).flatten().unsqueeze(0), hidden_states2.squeeze(0).flatten().unsqueeze(0)))\n",
    "    sim_ls = [i.item() for i in sim_ls]\n",
    "    print(np.mean(sim_ls), sim_ls)\n",
    "    \n",
    "    return np.mean(sim_ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy top-down search: start MERGE_LAYERS below HIGHEST_LAY and move the base index downwards.\n",
    "lay = HIGHEST_LAY - MERGE_LAYERS\n",
    "# NOTE(review): last_merge_flag is never read in the visible cells.\n",
    "last_merge_flag = False\n",
    "\n",
    "# Probe prompts on which the original and the merged model are compared.\n",
    "sents= ['Torreorgaz is a municipality in the', 'The 81st Mechanised Brigade () is a']\n",
    "\n",
    "while lay >= LOWEST_LAY:\n",
    "    print(lay)\n",
    "    print('current model layer', len(baichuan_13B_copy_to_compress.model.layers))\n",
    "    # Build a candidate in which the MERGE_LAYERS-1 layers after `lay` are folded into layer `lay`.\n",
    "    tmp_merged_model = merge_layers_return_model(baichuan_13B_copy_to_compress, lay, MERGE_LAYERS-1)\n",
    "    # The candidate is always compared against the untouched original model.\n",
    "    sim_value = cal_last_hidden_sim(baichuan_13B, tmp_merged_model, tokenizer, sents)\n",
    "    if sim_value > THRESHOLD:\n",
    "        # Accept: the candidate becomes the new working model.\n",
    "        baichuan_13B_copy_to_compress = tmp_merged_model\n",
    "        lay -= INTERVAL\n",
    "        # If the index now lies beyond the shortened layer list, pull it back inside.\n",
    "        if lay >= len(baichuan_13B_copy_to_compress.model.layers):\n",
    "            lay = len(baichuan_13B_copy_to_compress.model.layers) - 1 - MERGE_LAYERS\n",
    "    else:\n",
    "        # Reject: the working model is left as is; retry one layer lower.\n",
    "        lay -= 1\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bring the config's layer count in line with the number of layers left after merging.\n",
    "baichuan_13B_copy_to_compress.config.num_hidden_layers = len(baichuan_13B_copy_to_compress.model.layers)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
